"""
部分旋转 [:,:,20]
对 loss的研究，extraction_rate, c_init,c_fin
"""
import argparse
from collections import defaultdict

from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models import encoders, decoders
from disvae.models.anneal import *
from disvae.models.losses import BtcvaeLoss, get_loss_s, _kl_normal_loss, _reconstruction_loss
from exps.models import *
from exps.visualize import *
from exps.patterns import data_generator
import numpy as np
from utils.helpers import set_seed


# Default hyper-parameters; each one is turned into a CLI flag below so it
# can be overridden on the command line (flag type inferred from default).
hyperparameter_defaults = dict(
    batch_size=256,
    learning_rate=5e-4,
    epochs=141,
    beta=20,  # weight of the KL term in the training loss
    dimension=6,  # total latent dimensionality of the encoder
    img_id=0,  # which pre-generated pattern file to load
    random_seed=210,
    file=__file__.split('/')[-1]  # script name, logged for bookkeeping
)

parser = argparse.ArgumentParser()
for key, value in hyperparameter_defaults.items():
    parser.add_argument(f'--{key}', default=value, type=type(value))
args = parser.parse_args()

# Register the run config with wandb and expose it globally as `config`.
# Seeding happens after arg parsing so the seed itself is configurable.
wandb.init(project="exps", config=args)
config = args
set_seed(config.random_seed)


def train(dl, model, epochs):
    """Train the multi-decoder VAE for ``epochs`` passes over ``dl``.

    Relies on the module-level globals ``encoder``, ``decoder_list``,
    ``config`` and ``dim``. Every 40 epochs, logs latent histograms,
    reconstructions and diagnostic figures to wandb.

    Args:
        dl: DataLoader yielding (image, label) batches.
        model: container module holding the encoder and all decoders
            (moved CPU<->GPU around the plotting helpers).
        epochs: number of training epochs.

    Returns:
        The ``storer`` metrics dict from the final epoch.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    # Linear capacity schedule: C ramps from 0 to 1.5 over all updates.
    # Currently only logged; the capacity-based KL (commented out below)
    # is not active.
    inc = 1.5 / len(dl) / epochs
    C = 0
    for e in trange(epochs):
        storer = defaultdict(list)

        for img, label in dl:
            img = img.view(-1, 1, 64, 64).cuda()
            mean, logvar = encoder(img)
            latent_sample = reparameterize(mean, logvar)

            # Per-dimension KL(q(z|x) || N(0, I)), averaged over the batch.
            latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)
            kl = config.beta * latent_kl.sum()
            # torch.where(latent_kl<C,config.beta *latent_kl,1000*latent_kl).sum()
            C += inc

            # Each decoder reconstructs from its own 3-dim slice of the latent.
            recon_list = []
            for i, decoder in enumerate(decoder_list):
                recon_list.append(decoder(latent_sample[:, 3 * i:3 * i + 3]))

            # BUG FIX: sum the reconstruction loss of EVERY decoder output;
            # the original summed recon_list[0] repeatedly, so decoders
            # beyond the first received no reconstruction gradient.
            recon_loss = _reconstruction_loss(img, recon_list[0])
            for i in range(1, len(recon_list)):
                recon_loss = recon_loss + _reconstruction_loss(img, recon_list[i])
            loss = kl + recon_loss

            opt.zero_grad()
            loss.backward()
            opt.step()

            for i in range(dim):
                storer[f'kl_{i}'].append(latent_kl[i].item())

        # Collapse the per-batch lists into epoch means for logging.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)

        if (e + 1) % 40 == 0:
            storer['C'] = C
            for i in range(dim):
                storer[f'z_{i}'] = wandb.Histogram(latent_sample[:, i].data.cpu().numpy())
            # BUG FIX: `recon` was undefined here (NameError at epoch 40);
            # log the first decoder's reconstruction of the last batch.
            storer['recon'] = wandb.Image(recon_list[0][0, 0])
            storer['img'] = wandb.Image(img[0, 0])

            # Rank latent dims by KL and plot the top-3 against ground truth.
            kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
            _, index = kl_loss.sort(descending=True)

            preds, targets = get_preds(model, dl)
            preds, targets = preds[:1000, index[:3]], targets[:1000]
            fig = gt_vs_latent(preds.cpu(), targets.cpu())
            storer['correlated'] = wandb.Image(fig)

            # Traversal plotting runs on CPU; restore GPU/train mode after.
            model.cpu()
            model.eval()
            fig = plt_sample_traversal(None, model, 7, index, r=2)
            storer['traversal'] = wandb.Image(fig)
            model.cuda()
            model.train()

        # storer['epoch'] = e
        wandb.log(storer, step=e, sync=True)
        # plt.close()
    return storer


# generate data
img_id = config.img_id
# Load the pre-generated base pattern image from patterns/<id>.pat.
img = torch.load(f'patterns/{img_id}.pat')

# Generate rotated variants of the pattern plus their ground-truth labels.
imgs, labels = data_generator.gen_rotation(img, img.shape)

epochs = config.epochs
dim = config.dimension  # total latent dimensionality

dataset = imgs.reshape(-1, 1, 64, 64)
# NOTE(review): reshape to (-1, 3) assumes 3 ground-truth factors per
# sample — confirm against data_generator.gen_rotation.
nlabels = labels.reshape(-1, 3)

# One shared encoder; two decoders, each reading a 3-dim latent slice.
encoder = encoders.EncoderBurgess((1, 64, 64), dim)
decoder_list = [decoders.DecoderBurgess((1, 64, 64), 3) for _ in range(2)]
# Sequential is used only as a parameter container here (so one optimizer
# and one .cuda()/.train() call cover all sub-modules), not for forward().
model = torch.nn.Sequential(*decoder_list[1:])
model.add_module('encoder', encoder)
model.add_module('decoder', decoder_list[0])

model.train()
model.cuda()

dl = DataLoader(list(zip(dataset, nlabels)), pin_memory=True,
                num_workers=0, batch_size=config.batch_size, shuffle=True)

storer = train(dl, model, epochs)

# Rank latent dimensions by their final average KL (most informative first).
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
_, index = kl_loss.sort(descending=True)
print(index)
model.eval()
model.cpu()

# Visualize the 3 highest-KL latent dimensions as a 3D point cloud.
preds, targets = get_preds(model, dl)
points = preds[:, index[:3]].cpu()
storer['point_cloud'] = wandb.Object3D(points.numpy())

# Final latent-traversal figure over the ranked dimensions.
fig = plt_sample_traversal(None, model, 7, index, r=2)
storer['traversal'] = wandb.Image(fig)

wandb.log(storer)
plt.show()
